import math
import random
import numpy as np
import util
import csv
import os
import torch.utils.data as Data
from torch.utils.data import Dataset
import torch
from torch import nn as nn
from torch import Tensor
import torch.optim as optim
from tqdm import tqdm
from util import get_project_root, latest_checkpoint, softmax, segment
from sklearn.model_selection import train_test_split
# from net2d.loss_function import LGMLoss_v0
from torchinfo import summary
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoderLayer, TransformerDecoder
from torch.utils.tensorboard import SummaryWriter
from abc import ABCMeta, abstractmethod
from sklearn.metrics import classification_report
from sequence_data import class_dict, Sequence_Data_V1


# Absolute path of the project root; used to build checkpoint and report paths.
PROJECT_ROOT = util.get_project_root()


def generate_square_subsequent_mask(sz: int) -> Tensor:
    """Build an additive causal attention mask of shape (sz, sz).

    Entries strictly above the diagonal are -inf; the diagonal and everything
    below it are zero, so each position may only attend to itself and earlier
    positions.
    """
    filled = torch.full((sz, sz), float('-inf'))
    return torch.triu(filled, diagonal=1)


class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal positional encodings to a sequence-first input.

    pe[pos, 0, 2i]   = sin(pos / 10000^(2i/d_model))
    pe[pos, 0, 2i+1] = cos(pos / 10000^(2i/d_model))
    The encodings are stored as a non-trainable buffer and dropout is applied
    after the addition.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pos = torch.arange(max_len).unsqueeze(1)
        freq = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = pos * freq  # (max_len, d_model / 2)
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        pe[:, 0, 1::2] = torch.cos(angles)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """Add positional encodings to ``x``.

        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        seq_len = x.size(0)
        return self.dropout(x + self.pe[:seq_len])


class ScoreNet(nn.Module, metaclass=ABCMeta):
    """Abstract base class for the scoring networks.

    Provides checkpoint loading, whole-sequence inference (``predict``) and
    the shared post-processing of the grade / action-sequence / score
    outputs.  Subclasses implement ``forward``.
    """

    def __init__(self, model_name, input_size, use_cuda):
        super(ScoreNet, self).__init__()
        self.model_name = model_name  # used to locate checkpoints and reports
        self.input_size = input_size  # dimensionality of the per-frame features
        self.use_cuda = use_cuda

    def load_trained_model(self, model_path):
        '''
        Load the newest checkpoint found under ``model_path`` into this model.
        :param model_path: directory containing checkpoint files
        :return: None (weights are restored in place)
        '''
        model_file = latest_checkpoint(model_path)

        if model_file is not None:
            print("loading >>> ", model_file, " ...")
            checkpoint = torch.load(model_file)

            if isinstance(checkpoint, dict):
                self.load_state_dict(checkpoint['state_dict'])
        return

    @abstractmethod
    def forward(self, src: Tensor, src_mask: Tensor):
        pass

    def predict(self, x, **kwargs):
        '''
        Run inference on one feature sequence and post-process the outputs.
        :param x: per-frame feature sequence
        :param kwargs: must contain "bptt" (segment length) and "step"
            (temporal pooling step of the sequence head)
        :return: dict with the segment list, per-frame predictions and
            probabilities, the grade prediction and the two scores
        '''
        bptt = kwargs["bptt"]
        step = kwargs["step"]

        length = len(x)
        seg_count = length // bptt

        # truncate to a whole number of bptt-sized segments
        X = np.array(x[:seg_count * bptt], dtype=np.float32)
        tX = torch.from_numpy(X).unsqueeze(dim=0)
        grade_outs, seq_outs, s1_out, s2_out = self.forward(tX, None)

        grade_prob, grade_preds = torch.max(grade_outs.data, 1)
        grade_preds = grade_preds.item()

        seq_preds, seq_prob, raw_seq_preds = ScoreNet._get_predictions(seq_outs.squeeze().cpu().numpy())

        seg_preds = ScoreNet._segment(seq_preds, seq_prob, step)
        # expand window-level predictions back to per-frame resolution
        seq_preds, seq_prob, raw_seq_preds = ScoreNet._restore_predictions(seq_preds, seq_prob, raw_seq_preds, step)

        s1, s2 = self.postprocess(s1_out.item(), s2_out.item(), seq_preds)
        return {"seg_preds": seg_preds,
                "seq_preds": seq_preds, "seq_prob": seq_prob, "raw_seq_preds": raw_seq_preds,
                "grade_preds": grade_preds, "scores": (s1, s2),
                }

    def postprocess(self, s1, s2, seq_preds):
        '''
        Convert the raw [0, 1] scores to percentages, clamp out-of-range
        values to 0, and zero a score when no action of its group appears
        in the predicted sequence.
        :param s1: raw score 1 (expected in [0, 1])
        :param s2: raw score 2 (expected in [0, 1])
        :param seq_preds: per-frame action predictions
        :return: (s1, s2) as percentages
        '''
        s1 = round(100 * s1, 1)
        s2 = round(100 * s2, 1)
        if s1 < 0 or s1 > 100:
            s1 = 0
        if s2 < 0 or s2 > 100:
            s2 = 0
        actions = np.unique(seq_preds)
        actions = actions[actions > 0]
        # actions 1..7 belong to group A (score 1), 8 and above to group B (score 2)
        actions_A = np.sum(actions < 8)
        actions_B = np.sum(actions > 7)
        if actions_A == 0:
            s1 = 0
        if actions_B == 0:
            s2 = 0
        return s1, s2

    @staticmethod
    def _get_predictions(features):
        '''
        Post-process the raw per-window logits.
        :param features: logits, shape (num_windows, num_classes)
        :return: (locally-smoothed labels, per-action probabilities, raw labels)
        '''
        K = 24

        # shift by -1 so that class 0 (the BIO "O"/pad class) maps to -1
        pred = np.argmax(features, axis=1) - 1
        local_pred = np.zeros_like(pred)
        pred_len = len(pred)
        for i in range(pred_len):
            start = max(0, i - 2)
            stop = min(i + 3, pred_len)
            temp = pred[start:stop]
            # majority vote over a 5-wide window; labels are shifted by +1
            # before counting because np.bincount rejects negative values
            # (-1 is the background label), then shifted back
            local_pred[i] = np.argmax(np.bincount(temp + 1)) - 1

        local_pred[local_pred >= K] = -1
        prob = softmax(features)

        # the trailing 23 columns (1 + 7 + 4, ...) are the real actions;
        # column 0 is the "O" of the BIO stream
        final_prob = prob[:, 1:K]
        return local_pred, final_prob, pred

    @staticmethod
    def _segment(pred, prob, STEP):
        '''
        Split the predicted action labels into contiguous segments.
        :param pred: window-level action labels
        :param prob: matching per-action probabilities
        :param STEP: frames per window; segment bounds are reported in frames
        :return: list of (start, end, label, length, prob mean, prob std)
        '''
        start = 0
        pre_label = pred[start]
        nx_label = pre_label  # fallback so a length-1 input still closes a segment
        seg = []
        for i in range(1, len(pred)):
            nx_label = pred[i]
            if pre_label != nx_label:
                seg_prob = prob[start:i, pre_label]
                prob_count = i - start
                prob_std = np.std(seg_prob)
                prob_mean = np.mean(seg_prob)
                # start frame, end frame, label, length, prob mean, prob std
                seg.append((start * STEP, (i - 1) * STEP, pre_label, prob_count * STEP, prob_mean, prob_std))
                start = i
                pre_label = nx_label

        # close the trailing segment
        end = len(pred)
        seg_prob = prob[start:end, nx_label]
        prob_std = np.std(seg_prob)
        prob_mean = np.mean(seg_prob)
        seg.append((start * STEP, end * STEP, nx_label, (end - start) * STEP, prob_mean, prob_std))

        return seg

    @staticmethod
    def _restore_predictions(seq_preds, seq_prob, raw_seq_preds, STEP):
        '''
        Repeat each window-level prediction STEP times so the outputs are
        back at per-frame resolution.
        '''
        seg_count = len(seq_preds)
        a = np.reshape(np.arange(0, seg_count, ), (-1, 1))
        b = np.ones((seg_count, STEP))
        c = a * b
        index = np.squeeze(np.reshape(c, (1, -1)), 0).astype(np.int32)
        r_preds = seq_preds[index]
        r_prob = seq_prob[index]
        r_raw_preds = raw_seq_preds[index]
        return r_preds, r_prob, r_raw_preds


class ScoreNet_T5(ScoreNet):
    '''
    [Transformer 2+2] + maxpool + FC: 2-layer Transformer encoder plus
    2-layer decoder, dual outputs, with temporal attention computed on the
    decoded sequence.
    '''

    def __init__(self, model_name, input_size, output_size, bptt, step, use_cuda):
        '''
        Initialize the scoring network.
        :param model_name: name used for checkpoints and reports
        :param input_size: dimensionality of the input features
        :param output_size: number of output grade classes
        :param bptt: segment length the input sequence is split into
        :param step: temporal pooling step of the action-sequence head
        :param use_cuda: whether to use the GPU
        '''
        super(ScoreNet_T5, self).__init__(model_name, input_size, use_cuda)

        trans_input_size = 32
        d_hid = 200  # dimension of the feedforward network model in nn.TransformerEncoder
        nlayers = 2  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
        nhead = 4  # number of heads in nn.MultiheadAttention
        dropout = 0.2  # dropout probability
        self.pos_encoder = PositionalEncoding(trans_input_size, dropout)
        encoder_layers = TransformerEncoderLayer(trans_input_size, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)

        self.trans_input_size = trans_input_size
        decoder_layers = TransformerDecoderLayer(trans_input_size, nhead, d_hid, dropout)
        self.transformer_decoder = TransformerDecoder(decoder_layers, nlayers)
        # projects raw per-frame features down to the transformer width
        self.pre_encoder = nn.Linear(input_size, trans_input_size)
        # per-window action classifier (24 classes)
        self.action_decoder = nn.Linear(trans_input_size, 24)

        # grade / score prediction heads
        score_num_hiddens = (trans_input_size, 128)
        self.decoder = nn.Linear(score_num_hiddens[0], score_num_hiddens[1])
        self.attention = nn.Linear(trans_input_size, score_num_hiddens[1])
        self.attention_s1 = nn.Linear(trans_input_size, score_num_hiddens[1])
        self.attention_s2 = nn.Linear(trans_input_size, score_num_hiddens[1])
        self.maxpool = nn.AdaptiveMaxPool1d(1)
        self.fc = nn.Linear(score_num_hiddens[1], output_size)
        self.softmax = nn.Softmax(dim=1)
        self.bptt = bptt
        self.time_step = step
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.decoder_s1 = nn.Linear(score_num_hiddens[0], score_num_hiddens[1])
        self.decoder_s2 = nn.Linear(score_num_hiddens[0], score_num_hiddens[1])
        self.fc_s1 = nn.Linear(score_num_hiddens[1], 1)
        self.fc_s2 = nn.Linear(score_num_hiddens[1], 1)

    def forward(self, src: Tensor, src_mask: Tensor):
        '''
        Inference.
        :param src: input features; assumes shape (1, length, input_size) with
            length a multiple of bptt — TODO confirm against callers
        :param src_mask: input mask (not applied here: the encoder is called
            with mask=None)
        :return: grade_outs (grade probabilities), seq_out (per-window action
            logits), s1, s2 (scores clipped to [0, 1])
        '''
        bptt = self.bptt
        trans_input_size = self.trans_input_size
        pre_input = self.pre_encoder(src)

        length = src.size(1)
        seg_count = length // bptt
        # reinterpret the flat sequence as (bptt, seg_count, dim) and move the
        # segment axis first -> (seg_count, bptt, dim)
        CX = pre_input.view(bptt, seg_count, trans_input_size)
        CX = CX.permute([1, 0, 2]).contiguous()

        # input = CX.squeeze()
        input = CX
        # Transformer predicts per-frame classes and action-sequence intervals
        input = input * math.sqrt(trans_input_size)
        input = self.pos_encoder(input)
        seq_trans_out = self.transformer_encoder(input, mask=None)
        seq_trans_out = self.transformer_decoder(input, seq_trans_out)
        seq_trans_out = seq_trans_out.permute([1, 0, 2]).contiguous()
        seq_fout = seq_trans_out.view(-1, trans_input_size)

        # average each window of `time_step` frames, then classify the window
        seq_tmp = seq_trans_out.view(-1, self.time_step, trans_input_size)
        seq_tmp = seq_tmp.permute([0, 2, 1])
        seq_avg = self.avgpool(seq_tmp)
        seq_avg = seq_avg.permute([0, 2, 1]).squeeze().contiguous()
        seq_out = self.action_decoder(seq_avg)

        # attention weights for each head
        seq_att = torch.tanh(self.attention(seq_fout))
        att_s1 = torch.tanh(self.attention_s1(seq_fout))
        att_s2 = torch.tanh(self.attention_s2(seq_fout))
        # grade head: attention-weighted features, max-pooled over time
        x = self.decoder(seq_fout)
        x1 = x * seq_att
        x2 = self.fc(x1)
        x2 = x2.permute([1, 0]).unsqueeze(dim=0)
        x3 = self.maxpool(x2)
        x4 = x3.reshape(x3.size(0), -1)
        grade_outs = self.softmax(x4)

        # score heads: attention-weighted features, average-pooled over time
        xs1 = self.decoder_s1(seq_fout)
        x1 = xs1 * att_s1
        x2 = self.fc_s1(x1)
        x2 = x2.permute([1, 0]).unsqueeze(dim=0)
        s1 = self.avgpool(x2).squeeze()
        s1 = torch.clip(s1, 0, 1)

        xs2 = self.decoder_s2(seq_fout)
        x1 = xs2 * att_s2
        x2 = self.fc_s2(x1)
        x2 = x2.permute([1, 0]).unsqueeze(dim=0)
        s2 = self.avgpool(x2).squeeze()
        s2 = torch.clip(s2, 0, 1)
        return grade_outs, seq_out, s1, s2


class ScoreNet_T6(ScoreNet):
    '''
    [Transformer 2+2] + maxpool + FC: 2-layer Transformer encoder plus
    2-layer decoder, dual outputs, with temporal attention computed on the
    decoded sequence.  Differs from ScoreNet_T5 in that trans_input_size is
    configurable and the grade head uses a learned softmax attention pooling
    instead of max pooling.
    '''

    def __init__(self, model_name, input_size, output_size, trans_input_size, bptt, step, use_cuda):
        '''
        Initialize the scoring network.
        :param model_name: name used for checkpoints and reports
        :param input_size: dimensionality of the input features
        :param output_size: number of output grade classes
        :param trans_input_size: transformer model width (d_model)
        :param bptt: segment length the input sequence is split into
        :param step: temporal pooling step of the action-sequence head
        :param use_cuda: whether to use the GPU
        '''
        super(ScoreNet_T6, self).__init__(model_name, input_size, use_cuda)

        self.model_name = model_name
        # trans_input_size = 32
        d_hid = 200  # dimension of the feedforward network model in nn.TransformerEncoder
        nlayers = 2  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
        nhead = 4  # number of heads in nn.MultiheadAttention
        dropout = 0.2  # dropout probability
        self.pos_encoder = PositionalEncoding(trans_input_size, dropout)
        encoder_layers = TransformerEncoderLayer(trans_input_size, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)

        self.trans_input_size = trans_input_size
        decoder_layers = TransformerDecoderLayer(trans_input_size, nhead, d_hid, dropout)
        self.transformer_decoder = TransformerDecoder(decoder_layers, nlayers)
        # projects raw per-frame features down to the transformer width
        self.pre_encoder = nn.Linear(input_size, trans_input_size)
        # per-window action classifier (24 classes)
        self.action_decoder = nn.Linear(trans_input_size, 24)

        # grade / score prediction heads
        score_num_hiddens = (trans_input_size, 128)
        self.decoder = nn.Linear(score_num_hiddens[0], score_num_hiddens[1])

        # learned attention over time: one scalar weight per frame
        self.attention = nn.Sequential(nn.Linear(trans_input_size, score_num_hiddens[1]),
                                       nn.ReLU(),
                                       nn.Linear(score_num_hiddens[1], 1))
        # self.attention_s1 = nn.Sequential(nn.Linear(trans_input_size, score_num_hiddens[1]),
        #                                nn.ReLU(),
        #                                nn.Linear(score_num_hiddens[1], 1))
        # self.attention_s2 = nn.Sequential(nn.Linear(trans_input_size, score_num_hiddens[1]),
        #                                nn.ReLU(),
        #                                nn.Linear(score_num_hiddens[1], 1))
        # self.attention = nn.Linear(trans_input_size, score_num_hiddens[1])
        self.attention_s1 = nn.Linear(trans_input_size, score_num_hiddens[1])
        self.attention_s2 = nn.Linear(trans_input_size, score_num_hiddens[1])

        self.maxpool = nn.AdaptiveMaxPool1d(1)
        self.fc = nn.Linear(score_num_hiddens[1], output_size)
        self.softmax = nn.Softmax(dim=1)
        self.bptt = bptt
        self.time_step = step
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.decoder_s1 = nn.Linear(score_num_hiddens[0], score_num_hiddens[1])
        self.decoder_s2 = nn.Linear(score_num_hiddens[0], score_num_hiddens[1])
        self.fc_s1 = nn.Linear(score_num_hiddens[1], 1)
        self.fc_s2 = nn.Linear(score_num_hiddens[1], 1)

    def forward(self, src: Tensor, src_mask: Tensor):
        '''
        Inference.
        :param src: input features; assumes shape (1, length, input_size) with
            length a multiple of bptt — TODO confirm against callers
        :param src_mask: input mask (not applied here: the encoder is called
            with mask=None)
        :return: grade_outs (grade probabilities), seq_out (per-window action
            logits), s1, s2 (two raw scores)
        '''
        bptt = self.bptt
        trans_input_size = self.trans_input_size
        pre_input = self.pre_encoder(src)

        length = src.size(1)
        seg_count = length // bptt
        # reinterpret the flat sequence as (bptt, seg_count, dim) and move the
        # segment axis first -> (seg_count, bptt, dim)
        CX = pre_input.view(bptt, seg_count, trans_input_size)
        CX = CX.permute([1, 0, 2]).contiguous()

        # input = CX.squeeze()
        input = CX
        # Transformer predicts per-frame classes and action-sequence intervals
        input = input * math.sqrt(trans_input_size)
        input = self.pos_encoder(input)
        seq_trans_out = self.transformer_encoder(input, mask=None)
        seq_trans_out = self.transformer_decoder(input, seq_trans_out)
        seq_trans_out = seq_trans_out.permute([1, 0, 2]).contiguous()
        seq_fout = seq_trans_out.view(-1, trans_input_size)  # N x trans_input_size

        # average each window of `time_step` frames, then classify the window
        seq_tmp = seq_trans_out.view(-1, self.time_step, trans_input_size)
        seq_tmp = seq_tmp.permute([0, 2, 1])
        seq_avg = self.avgpool(seq_tmp)
        seq_avg = seq_avg.permute([0, 2, 1]).squeeze().contiguous()
        seq_out = self.action_decoder(seq_avg)

        # attention weights over time, normalized across the N frames
        seq_att = self.attention(seq_fout)
        seq_att = torch.softmax(seq_att, dim=0)
        seq_att = seq_att.permute([1, 0])  # 1 x N

        # grade head: attention-weighted sum of the decoded features
        x = self.decoder(seq_fout)
        x1 = torch.mm(seq_att, x)
        x2 = self.fc(x1)
        grade_outs = self.softmax(x2)

        # seq_att = torch.softmax(self.attention(seq_fout), dim=1)
        # x = self.decoder(seq_fout)
        # x1 = x * seq_att
        # x2 = self.fc(x1)
        # x2 = x2.permute([1, 0]).unsqueeze(dim=0)
        # x3 = self.maxpool(x2)
        # x4 = x3.reshape(x3.size(0), -1)
        # grade_outs = self.softmax(x4)

        # att_s1 = self.attention_s1(seq_fout)
        # att_s1 = torch.softmax(att_s1, dim=0)
        # att_s1 = att_s1.permute([1, 0])  # 1 x N
        #
        # att_s2 = self.attention_s2(seq_fout)
        # att_s2 = torch.softmax(att_s2, dim=0)
        # att_s2 = att_s2.permute([1, 0])  # 1 x N
        #
        # # score heads
        # xs1 = self.decoder_s1(seq_fout)
        # attn = att_s1# * (1 - att_s2)
        # xs1 = torch.mm(attn, xs1)
        # s1 = self.fc_s1(xs1).squeeze()
        #
        # xs2 = self.decoder_s2(seq_fout)
        # attn = att_s2 #* (1 - att_s1)
        # xs2 = torch.mm(attn, xs2)
        # s2 = self.fc_s1(xs2).squeeze()
        # return grade_outs, seq_out, s1, s2

        # grade head (max-pool variant, disabled)
        # seq_att = torch.tanh(self.attention(seq_fout),)
        #
        # x = self.decoder(seq_fout)
        # x1 = x * seq_att
        # x2 = self.fc(x1)
        # x2 = x2.permute([1, 0]).unsqueeze(dim=0)
        # x3 = self.maxpool(x2)
        # x4 = x3.reshape(x3.size(0), -1)
        # grade_outs = self.softmax(x4)

        att_s1 = torch.tanh(self.attention_s1(seq_fout), )
        att_s2 = torch.tanh(self.attention_s2(seq_fout), )
        # score heads: attention-weighted features, average-pooled over time
        # (note: unlike ScoreNet_T5, the outputs are not clipped to [0, 1])
        xs1 = self.decoder_s1(seq_fout)
        x1 = xs1 * att_s1
        x2 = self.fc_s1(x1)
        x2 = x2.permute([1, 0]).unsqueeze(dim=0)
        s1 = self.avgpool(x2).squeeze()

        xs2 = self.decoder_s2(seq_fout)
        x1 = xs2 * att_s2
        x2 = self.fc_s2(x1)
        x2 = x2.permute([1, 0]).unsqueeze(dim=0)
        s2 = self.avgpool(x2).squeeze()
        return grade_outs, seq_out, s1, s2


class Sequence_Data_T5(Dataset):
    '''
    Sequence dataset built on per-frame image features.
    '''

    def __init__(self, feature, label, bptt, step):
        '''
        Initialize the dataset.
        :param feature: list of per-sample feature arrays extracted by the
            front-end network, each of shape (frames, feature_dim)
        :param label: list of (grade, (s1, s2), per-frame action labels)
        :param bptt: segment length
        :param step: temporal pooling step; must evenly divide bptt
        '''
        super(Sequence_Data_T5, self).__init__()
        self.feature = feature
        self.label = label
        self.bptt = bptt
        self.step = step
        # the model pools every `step` frames, so bptt must be a multiple of step
        assert bptt % step == 0, f"step({step}) must evenly divide bptt({bptt})"
        self.fdim = np.size(feature[0], 1)  # feature dimensionality

    def __len__(self):
        return len(self.label)

    def __getitem__(self, i):
        '''
        :return: (features, grade, pooled action labels, segment count,
                  scores normalized from percent to [0, 1])
        '''
        bptt = self.bptt
        X = self.feature[i]
        y, (s1, s2), seq_y = self.label[i]

        length = len(seq_y)
        seg_count = length // bptt

        # truncate to a whole number of bptt-sized segments
        X = X[:seg_count * bptt]
        tX = torch.from_numpy(X)

        seq_y = seq_y[:seg_count * bptt]
        sy = np.reshape(seq_y, (-1, self.step))
        # majority vote inside each window of `step` frames (24 label bins)
        pooled = np.zeros((sy.shape[0]))
        for j, window in enumerate(sy):
            pooled[j] = np.argmax(np.bincount(window, minlength=24))
        return tX, y, torch.from_numpy(pooled, ).long(), seg_count, torch.FloatTensor((s1 / 100, s2 / 100))


class ScoreNet_Trainer_T5():
    '''
    Trainer for the scoring models: data loading, checkpoint save/restore,
    the train/test loops, TensorBoard reporting and a final
    classification-report evaluation.
    '''

    def __init__(self, num_classes, use_gpu):
        super(ScoreNet_Trainer_T5, self).__init__()
        self.use_gpu = use_gpu
        self.device = torch.device("cuda:0" if self.use_gpu else "cpu")
        self.num_classes = num_classes

    def load_data(self, feature_path, data_file, bptt, step):
        '''
        Load the training data and build the train/test loaders.
        :param feature_path: directory with the per-sample image-feature files
        :param data_file: .npz archive whose "data" entry lists the samples
        :param bptt: sequence segment length
        :param step: temporal pooling step
        :return: None (loaders and split sizes are stored on self)
        '''
        data = np.load(data_file, allow_pickle=True)
        data_set = data["data"]
        print("loading ", data_file, ", data size =", len(data_set))
        train, test = train_test_split(data_set, test_size=0.3, shuffle=True, random_state=12)

        self.train_size = len(train)
        self.test_size = len(test)

        train_dst = Sequence_Data_V1(feature_path, train, bptt, step)
        test_dst = Sequence_Data_V1(feature_path, test, bptt, step)

        batch_size = 1  # must be 1: the loops below assume one sequence per batch
        self.train_loader = Data.DataLoader(dataset=train_dst, batch_size=batch_size, shuffle=True,
                                            collate_fn=None)
        self.test_loader = Data.DataLoader(dataset=test_dst, batch_size=batch_size, shuffle=False,
                                           collate_fn=None)
        self.batch_size = batch_size

    def load_model(self, model, model_file=None):
        '''
        Restore the model (and training state) from a checkpoint, if any.
        :param model: network whose weights are restored in place
        :param model_file: explicit checkpoint file; when None, the newest
            checkpoint under save_models/<model_name> is used
        :return: (begin_epoch, optimizer_state, scheduler_state, best_acc)
        '''
        begin_epoch = -1
        optimizer_state, scheduler_state = None, None
        checkpoint_dir = '{}/save_models/{}'.format(PROJECT_ROOT, model.model_name)
        best_acc = 0

        if model_file is None:
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)

            model_file = latest_checkpoint(checkpoint_dir)

        if model_file is not None:
            print("loading >>> ", model_file, " ...")
            checkpoint = torch.load(model_file)

            if isinstance(checkpoint, dict):
                model.load_state_dict(checkpoint['state_dict'])

                begin_epoch = checkpoint['epoch']
                if 'optimizer' in checkpoint:
                    optimizer_state = checkpoint['optimizer']
                if 'scheduler' in checkpoint:
                    scheduler_state = checkpoint['scheduler']
                if 'best_acc' in checkpoint:
                    best_acc = checkpoint['best_acc']

        return begin_epoch, optimizer_state, scheduler_state, best_acc

    def transfer_train(self, model, feature_path, data_file, bptt, step, num_epochs):
        '''
        Train the scoring model on transfer-learned features.
        :param model: the network to train
        :param feature_path: directory with the feature files
        :param data_file: .npz archive listing the samples
        :param bptt: sequence segment length
        :param step: temporal pooling step
        :param num_epochs: number of epochs to train
        :return: None
        '''
        self.load_data(feature_path, data_file, bptt=bptt, step=step)

        src_mask = generate_square_subsequent_mask(100)
        summary(model, input_data=(torch.randn(1, 100 * bptt, model.input_size), src_mask), device=self.device)

        # Observe that all parameters are being optimized
        optimizer = optim.SGD(model.parameters(), lr=1e-1, momentum=0.9, weight_decay=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=10,
                                                               factor=0.1,
                                                               min_lr=1e-4)  # mode='min': lr is multiplied by `factor` when the loss stops decreasing ('max' is the reverse)

        begin_epoch, self.optimizer_state, self.scheduler_state, best_acc = self.load_model(model, None)
        self.train_model(model, optimizer, scheduler, begin_epoch, num_epochs, best_acc,
                         back_size=1, assist_enable=False)

    def train_model(self, model, optimizer, scheduler, begin_epoch, num_epochs, best_acc,
                    back_size=1, assist_enable=False):
        '''
        Train the model.
        :param model: the network
        :param optimizer: optimizer
        :param scheduler: learning-rate scheduler
        :param begin_epoch: last epoch already trained (resume point)
        :param num_epochs: number of epochs to train
        :param best_acc: best historical accuracy
        :param back_size: batch size
        :param assist_enable: whether to use the center loss (unused here)
        :return: None
        '''
        batch_size = 1

        writer_dir = '{}/reports/{}'.format(PROJECT_ROOT, model.model_name)
        if not os.path.isdir(writer_dir):
            os.mkdir(writer_dir)
        writer = SummaryWriter(writer_dir, flush_secs=15)

        model_root = '{}/save_models/{}'.format(PROJECT_ROOT, model.model_name)
        if not os.path.exists(model_root):
            os.makedirs(model_root)

        # define loss function
        ce_criterion = nn.CrossEntropyLoss()
        mse_criterion = nn.MSELoss()
        if self.use_gpu:
            model.to(self.device)
            ce_criterion.to(self.device)
            mse_criterion.to(self.device)

        # restore optimizer/scheduler state captured by transfer_train
        if self.optimizer_state is not None:
            optimizer.load_state_dict(self.optimizer_state)
        if self.scheduler_state is not None:
            scheduler.load_state_dict(self.scheduler_state)

        report_data = {}
        for epoch in range(begin_epoch + 1, num_epochs + 1):
            running_loss = 0
            running_correct = 0
            running_correct2 = 0
            running_mse = 0
            seq_count = 0
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            model.train()
            with tqdm(self.train_loader, desc="training in Epoch {}/{}".format(epoch, num_epochs),
                      mininterval=2.0, ) as tq:
                for step, (x, ry, sy, seg_count, s1s2) in enumerate(tq):
                    b_x = x.to(self.device)
                    b_ry = ry.to(self.device)
                    b_sy = sy.to(self.device).squeeze()
                    s1s2 = s1s2.squeeze()
                    b_s1 = s1s2[0].to(self.device)
                    b_s2 = s1s2[1].to(self.device)
                    src_mask = generate_square_subsequent_mask(seg_count).to(self.device)

                    grade_outs, seq_outs, s1_out, s2_out = model(b_x, src_mask)
                    _, preds = torch.max(grade_outs.data, 1)
                    _, seq_preds = torch.max(seq_outs.data, 1)

                    # number of label changes: penalizes over-segmentation
                    pred_seg_count = torch.sum(torch.abs(torch.diff(seq_preds)) > 0)
                    seq_loss = ce_criterion(seq_outs, b_sy)
                    score_loss = mse_criterion(s1_out, b_s1) + mse_criterion(s2_out, b_s2)
                    loss = ce_criterion(grade_outs, b_ry) + seq_loss + 0.01 * pred_seg_count + 1 * score_loss

                    optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                    optimizer.step()

                    running_loss += loss.item()
                    running_correct += torch.sum(preds == b_ry.data).item()

                    running_correct2 += torch.sum((seq_preds == b_sy.data)).item()
                    seq_count += b_sy.size(0)

                    running_mse += score_loss.item()
                    epoch_mse = score_loss.item() / (step + 1)
                    epoch_loss = running_loss / (step + 1)
                    epoch_acc = running_correct / (step * batch_size + b_x.size(0))
                    epoch_seq_acc = running_correct2 / seq_count

                    tq.set_postfix(Lr=lr, Loss=epoch_loss, seq_loss=seq_loss.item(), train_grade_acc=epoch_acc,
                                   train_seq_acc=epoch_seq_acc, score_mse=epoch_mse)

                scheduler.step(running_loss)
                train_acc = running_correct / self.train_size

            report_data = {"loss": epoch_loss, "acc": epoch_acc, "seq_acc": epoch_seq_acc, "mse": epoch_mse, }
            self.write_report(writer, epoch, report_data, isTrain=True)

            # run a test pass every 2 epochs
            if (epoch + 1) % 2 == 0:
                model.eval()

                running_loss = 0
                running_correct = 0
                running_correct2 = 0
                running_mse = 0
                seq_count = 0
                with tqdm(self.test_loader, desc="testing in Epoch {}/{}".format(epoch, num_epochs),
                          mininterval=2.0, colour='green') as tq:
                    for step, (x, ry, sy, seg_count, s1s2) in enumerate(tq):
                        b_x = x.to(self.device)
                        b_ry = ry.to(self.device)
                        b_sy = sy.to(self.device).squeeze()
                        s1s2 = s1s2.squeeze()
                        b_s1 = s1s2[0].to(self.device)
                        b_s2 = s1s2[1].to(self.device)
                        src_mask = generate_square_subsequent_mask(seg_count).to(self.device)

                        grade_outs, seq_outs, s1_out, s2_out = model(b_x, src_mask)
                        _, preds = torch.max(grade_outs.data, 1)
                        _, seq_preds = torch.max(seq_outs.data, 1)

                        pred_seg_count = torch.sum(torch.abs(torch.diff(seq_preds)) > 0)
                        seq_loss = ce_criterion(seq_outs, b_sy)
                        score_loss = mse_criterion(s1_out, b_s1) + mse_criterion(s2_out, b_s2)
                        loss = ce_criterion(grade_outs, b_ry) + seq_loss + 0.01 * pred_seg_count + score_loss

                        running_loss += loss.item()
                        running_correct += torch.sum(preds == b_ry.data).item()

                        running_correct2 += torch.sum((seq_preds == b_sy.data)).item()
                        seq_count += b_sy.size(0)

                        running_mse += score_loss.item()
                        epoch_mse = score_loss.item() / (step + 1)
                        epoch_loss = running_loss / (step + 1)
                        epoch_acc = running_correct / (step * batch_size + b_x.size(0))
                        epoch_seq_acc = running_correct2 / seq_count

                        tq.set_postfix(Lr=lr, Loss=epoch_loss, seq_loss=seq_loss.item(), test_grade_acc=epoch_acc,
                                       test_seq_acc=epoch_seq_acc, score_mse=epoch_mse)

                    test_acc = running_correct / self.test_size

                report_data = {"loss": epoch_loss, "acc": epoch_acc, "seq_acc": epoch_seq_acc, "mse": epoch_mse, }
                self.write_report(writer, epoch, report_data, isTrain=False)

            # NOTE(review): assumes a test pass has run before the first save
            # epoch (true for the 2/10 cadence here), otherwise test_acc would
            # be unbound — confirm if the cadence changes
            if (epoch + 1) % 10 == 0:
                # save model
                if best_acc < test_acc:
                    best_acc = test_acc  # min(test_acc, train_acc)
                self.save_model(model, best_acc, epoch, model_root, optimizer, scheduler, test_acc, train_acc)

        writer.close()

    def save_model(self, model, best_acc, epoch, model_root, optimizer, scheduler, test_acc, train_acc):
        '''
        Save a full training checkpoint (weights + optimizer + scheduler).
        The file name embeds the epoch and the train/test accuracies.
        '''
        save_states = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'best_acc': best_acc
        }
        torch.save(save_states,
                   "{}/{}_cp-{:04d}-{:.4f}-{:.4f}.pth".format(model_root, model.model_name,
                                                              epoch, train_acc, test_acc), )
        # _use_new_zipfile_serialization=False)  # for compatibility with older torch versions

    def write_report(self, writer, epoch, report_data, isTrain=True):
        '''Write the epoch metrics to TensorBoard under train/ or test/.'''
        if isTrain:
            writer.add_scalar('train/loss', report_data['loss'], epoch)
            writer.add_scalar('train/acc', report_data['acc'], epoch)
            writer.add_scalar('train/seq_acc', report_data['seq_acc'], epoch)
            writer.add_scalar('train/mse', report_data['mse'], epoch)
        else:
            writer.add_scalar('test/loss', report_data['loss'], epoch)
            writer.add_scalar('test/acc', report_data['acc'], epoch)
            writer.add_scalar('test/seq_acc', report_data['seq_acc'], epoch)
            writer.add_scalar('test/mse', report_data['mse'], epoch)

    def evaluate(self, model, ):
        '''
        Evaluate the latest checkpoint on both the train and test splits and
        print classification reports plus the L1 score loss.
        '''
        # define loss function
        ce_criterion = nn.CrossEntropyLoss()
        l1_criterion = nn.L1Loss()
        if self.use_gpu:
            model.to(self.device)
            ce_criterion.to(self.device)
            l1_criterion.to(self.device)

        self.load_model(model, None)

        model.eval()
        for mode, data_loader in zip(("train", "test"), (self.train_loader, self.test_loader)):
            results = {"preds": [], "seq_preds": [], "score_loss": 0, "labels": [], "seq_labels": []}
            for step, (x, ry, sy, seg_count, s1s2) in enumerate(data_loader):
                b_x = x.to(self.device)
                b_ry = ry.to(self.device)
                b_sy = sy.to(self.device).squeeze()
                s1s2 = s1s2.squeeze()
                b_s1 = s1s2[0].to(self.device)
                b_s2 = s1s2[1].to(self.device)
                src_mask = generate_square_subsequent_mask(seg_count).to(self.device)

                grade_outs, seq_outs, s1_out, s2_out = model(b_x, src_mask)
                _, preds = torch.max(grade_outs.data, 1)
                _, seq_preds = torch.max(seq_outs.data, 1)

                score_loss = l1_criterion(s1_out, b_s1) + l1_criterion(s2_out, b_s2)
                results["score_loss"] += score_loss.item()
                results["preds"].extend(preds.cpu().numpy())
                results["seq_preds"].extend(seq_preds.cpu().numpy())
                results["labels"].extend(b_ry.cpu().numpy())
                results["seq_labels"].extend(b_sy.cpu().numpy())
            results["score_loss"] = results["score_loss"] / len(data_loader)

            print(mode, "score_loss = ", results["score_loss"])
            print(classification_report(results["labels"], results["preds"], zero_division=0))
            print(classification_report(results["seq_labels"], results["seq_preds"], zero_division=0))

def train_res2net50():
    """Train and evaluate ScoreNet_T6 on res2net50 transfer-learned features."""
    num_classes = 3
    bptt, step = 40, 4
    input_size = 2048

    trainer = ScoreNet_Trainer_T5(num_classes=num_classes, use_gpu=True)

    # features extracted by a transfer-learned res2net50 front end
    model_name = f"ScoreNet_T6_{input_size}"
    feature_path = "H:/transfomer代码测试/fps6/res2net50"
    data_file = "H:/transfomer代码测试/fps6/res2net50_fps6_trans.npz"

    model = ScoreNet_T6(model_name=model_name, input_size=2048, output_size=num_classes,
                        use_cuda=True, bptt=bptt, step=step, trans_input_size=32)
    trainer.transfer_train(model, feature_path, data_file, bptt, step, num_epochs=300)
    trainer.evaluate(model)


# def train_ucf_res2net18cbam_b():
#     num_classes = 3
#     st = ScoreNet_Trainer_T5(num_classes=num_classes, use_gpu=True)
#     bptt = 40
#     step = 4
#     input_size = 2048
#
#     # 自制主干提取的特征
#     model_name = f"ScoreNet_S5B_{input_size}"
#     feature_path = f"{PROJECT_ROOT}/data/ucf_res2net18cbam_b"
#     model = ScoreNet_T5(model_name=model_name, input_size=2048, output_size=num_classes,
#                         use_cuda=True, bptt=bptt, step=step, trans_input_size=32)
#     st.custom_backbone_train(model, feature_path, bptt, step, num_epochs=300)
#     st.evaluate(model)

def train_resnet2p1d18():
    """Train and evaluate ScoreNet_T6 on resnet2p1d18 features (12 fps)."""
    num_classes = 3
    bptt, step = 40, 2
    input_size = 4096

    trainer = ScoreNet_Trainer_T5(num_classes=num_classes, use_gpu=True)

    # features extracted by a transfer-learned R(2+1)D-18 front end
    model_name = f"ScoreNet_3DT6_{input_size}"
    feature_path = f"{PROJECT_ROOT}/train_data/fps12/resnet2p1d18"
    data_file = f"{PROJECT_ROOT}/train_data/fps12/resnet2p1d18_fps12_trans.npz"

    model = ScoreNet_T6(model_name=model_name, input_size=input_size, output_size=num_classes,
                        use_cuda=True, bptt=bptt, step=step, trans_input_size=32)
    trainer.transfer_train(model, feature_path, data_file, bptt, step, num_epochs=300)
    trainer.evaluate(model)


if __name__ == '__main__':
    # Entry point: train the res2net50-feature variant by default.
    train_res2net50()
    # train_resnet2p1d18()
